Jianbin Chang, Hongbin Liu
全分片数据并行(FSDP)简介
- FSDP 工作原理:ZeRO-DP 分片策略
- FSDP 工作原理:FSDP 单元 (FSDP-Unit
- FSDP 工作原理:AllGather
- FSDP 工作原理:Reduce-Scatter
- FSDP 工作原理:不同 ZeRO-DP 策略中的 AllGather (AG) 与 Reduce-Scatter (RS
自定义 FSDP - 关键特性
- Megatron-Core 自定义 FSDP 关键特性架构图
- 并行通信对 Gemm 的不利影响
- 基于 UBR 的 SM 高效 NCCL 通信
- 基于 UBR 的 SM 高效 NCCL 通信:AllGather 示例
- NCCL 组调用 (Group Calls
理论分析:FSDP 中的通信隐藏
- FSDP 中的计算与通信成本:Transformer 层前向传播
- FSDP 中的计算与通信成本:Transformer 层后向传播
- FSDP 中的计算与通信成本:MoE 层后向传播
- FSDP 中的计算与通信成本:调优并行策略
基准测试与案例研究
- 来自 BioNeMo 的自定义 FSDP 反馈
- 基准测试:vs. FSDP2
- 基准测试:vs. 3D 并行
- FP8 训练中的挑战:CPU 瓶颈问题及 FSDP 解决方案
- 案例研究:Llama3-70B + FSDP
什么是 FSDP?
FSDP 的优势
FSDP 的工作原理
参考 PyTorch FSDP 介绍:https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html
FSDP 的分片策略能够显著减少每个 GPU 的内存消耗,但更激进的分片会增加通信开销。
(2 + 2 + K) * Ψ,在此示例中为 120GB。通信开销为 1 次 all-reduce。FSDP 单元是我们用于取消分片/重新分片参数的操作单元,确保这些参数具有相同的计算生命周期。
gather full params 操作收集完整的参数。计算后,通过 free peer shards 释放从其他 GPU 获取的参数分片。下图展示了在不同分片策略下,通信操作(AG 和 RS)与计算操作(前向 F、反向 B、优化器步骤 Opt step)的穿插执行方式。
我们已经支持在 M-Core 上使用 Torch FSDP2。M-Core Custom FSDP 是一个轻量级、面向性能的 FSDP 实现。
下表比较了 Torch FSDP2 和我们的 Custom FSDP:
| 特性 | FSDP2 (Torch) | Custom FSDP |
|---|---|---|
| ZeRO-DP 等效性 | ✅ | ✅ |
| 可扩展性 | ✅ | ✅✅ |
| 性能 | ✅ | ✅✅ |
| 内存管理 | ✅ | ✅✅ (可选的双缓冲分配器) |
| CPU 卸载 | ✅ | ❌ |
| 通信定制 | ❌ | ✅ (可选的基于用户缓冲区注册的 NCCL 通信) |
| FP8 训练 | ✅ | ✅✅ (原生 FP8 支持;更好的转置缓存管理;NCCL 组调用) |
| M-Core MoE 训练兼容 | ❌ | ✅ |
下图展示了 Custom FSDP 的核心组件及其交互关系。
FullyShardedDataParallel:顶层模块,负责动态删除 fp8 转置缓存,并在前向/后向 hook 中添加必要的原生 fp8 量化过程。AllGatherPipeline / GradReducePipeline:实现计算与通信的重叠,并使用 NCCL 组调用来提高效率。ParamAndGradBuffer:管理参数和梯度的缓冲区。DataParallelBuffer / TemporaryBucketAllocator:支持基于用户缓冲区注册的 NCCL 通信,并使用双缓冲分配器。torch.nn.Module:通过 Per-bucket NCCL communication: ZeRO-Copy 进行参数桶映射。通信流处理器(SM)对 Gemm 性能的影响
当 Gemm(通用矩阵乘法)和 Comm(通信)操作并发运行时,它们必须共享资源(SM、DRAM、L2 缓存等),这会导致性能下降。
Gemm 在与 Comm 重叠时吞吐量较低的原因:
LLM 训练中的 Gemm 对 SM 干扰比内存干扰更敏感。
下图展示了 Gemm 性能在减少可用内存带宽(左)和减少 SM 数量(右)时的下降情况。
利用 NVL Sharp、IB-Sharp 和 NCCL 实现高效集合通信
下图展示了在 256 个 GPU(8x32)上,不同 NCCL 算法在 All-Gather (AG) 和 Reduce-Scatter (RS) 操作中的带宽表现。新的 NVLS+IB/S 算法(黑色曲线)使用极少的 CTA(4个)即可达到高带宽。
*预计在 Megatron-Core v0.13 中发布。
什么是 UBR?
Torch 原生 AllGather:
基于 UBR 的 AllGather:
目前,基于用户缓冲区注册(UBR)的 NCCL 通信只能由带有双缓冲内存分配器的 Custom FSDP 支持。
用于加速多桶通信的 NCCL 组调用
示例:单个 FSDP 单元被拆分为三个通信组。
bfloat16, float8)和模块的不同,参数被分到了多个通信组(group 14, 15, 16, 17)中,这会降低通信效率。
上图对比了未使用(上)和使用(下)聚合 NCCL 通信的训练过程。在不使用聚合通信时,每个 FSDP 单元会为 FP8 和 BF16 的桶(bucket)分别触发大量独立的 NCCL 核心。使用聚合 NCCL 通信后,每个 FSDP 单元仅启动一个 NCCL 核心,显著减少了核心启动的开销,从而将通信操作融合,提高了效率。
前向传播计算量 (V_comp):
V_comp = 2 x (3h^2+4h^2+4h^2)/tp x SL/cp x mbs,其中 2 表示每次参数操作对应 2 个浮点运算(fma ops)。通信量 (V_comm):
V_comm = 2 x (3h^2+4h^2+4h^2)/tp,其中 2 表示参数存储于 bf16 格式。完全隐藏通信的条件:
t_comp > t_comm,即计算时间大于通信时间。tokens_per_micro_batch_per_gpu > (achieved Flops / achieved BW)变量定义:
h: 隐藏层维度 (hidden dimension)SL: 序列长度 (sequence length)mbs: 微批次大小 (micro batch size)tp: 张量并行大小 (tensor parallel size)cp: 上下文并行大小 (context parallel size)此页在上一页的分析基础上,增加了一个表格,展示了在不同张量并行(TP)规模下,为隐藏通信所需的最小微批次 token 数。
结论: 张量并行(TP)通信会与 FSDP 的通信竞争 NVLink 带宽。因此,我们需要调整 TP 和 FSDP 的规模以最大化性能。
数据表分析:
后向传播计算量 (V_comp):
V_comp = 4 x (3h^2+4h^2+4h^2)/tp x SL/cp x mbs,其中 4 表示每次参数操作对应 4 个浮点运算(fma ops)。通信量 (V_comm):
V_comm = (2+4) x (3h^2+4h^2+4h^2)/tp,其中 2 和 4 分别表示参数和梯度存储于 bf16 和 fp32 精度。完全隐藏通信的条件:
tokens_per_micro_batch_per_gpu > (3/2 * achieved Flops / achieved BW)数据表分析:
结论:
后向传播计算量 (V_comp):
V_comp = 4 x (4h^2+4h^2)/tp x SL/cp x mbs x topK,其中 4 表示 fma ops。通信量 (V_comm):
V_comm = (2+4) x (4h^2+4h^2)/tp x N/ep,其中 2 和 4 分别表示参数和梯度存储于 bf16 和 fp32 精度。完全隐藏通信的条件:
tokens_per_micro_batch_per_gpu > (3/2 * N/ep * 1/topK * achieved Flops / achieved BW)结论:
topK 值,可以大幅降低 FSDP 中的通信与计算比率,从而使 FSDP 中的通信能被很好地重叠(overlapped)。新增变量:
topK: 每个 token 被分派到的专家数量N: 专家总数本页总结了如何调整并行策略以更好地隐藏通信成本。
判据: tokens_per_micro_batch_per_gpu > (3/2 * N/ep * 1/topK * achieved Flops / achieved BW)
FSDP 通信暴露较少的条件:
topK 值该基准测试在 2 个节点、16xH100 GPU 上进行。
此图展示了在 Megatron-Core 上训练 LLaMA3-70B 时,自定义 FSDP (C-FSDP) 与 FSDP2 的扩展效率对比。
此图比较了自定义 FSDP (C-FSDP) 与 3D 并行在 H100 系统上的性能和内存使用情况。
性能 (TFLOPS/GPU): C-FSDP 的性能与 3D 并行相当。
内存使用 (GB): C-FSDP 在所有测试场景下都比 3D 并行占用更少的内存。
结论: 在经典工作负载上,自定义 FSDP 的性能与 3D 并行相当,并具有内存节省的优势。
mbs x seq-length(微批次大小 x 序列长度),增加了单个核心的计算负载,从而避免了因核心启动过快而产生的气泡问题。Llama3-8B FP8, CP2DP4 的 timeline 显示了大量细碎的计算核心。Llama3-8B FP8, FSDP 的 timeline 显示了更大、更连续的计算核心,有效地缓解了 CPU 启动瓶颈。FSDP64(64 位优化器状态)且不带激活重计算(activation recompute)时,发生内存溢出(OOM)。内存分析:
根本原因: 巨大的激活内存导致显存耗尽。堆栈跟踪信息显示,内存分配峰值发生在 transformer 的前向传播过程中。
FSDP64 + CP8(上下文并行度为8)的配置成功运行。结果:
内存分析: 通过上下文并行,激活内存(Activation Memory)被分散到多个 GPU 上,单个 GPU 的激活内存占用降低到约 25GB,从而避免了 OOM。下图展示了使用 FSDP64 + CP4 配置时的内存使用快照,激活内存峰值仍超过50GB,说明调整CP值对内存控制至关重要。
本节通过表格和性能分析器视图,对比了采用上下文并行(CP)与采用激活重计算(Activation Recompute)两种方案的性能差异。
实验设置:
- 模型: Llama3-70b
- 计算精度: BF16
- GPU数量: 64
- 序列长度 (SEQ_LEN): 8194
- FSDP 并行度: 64
- 张量并行度 (TP): 1
- 流水线并行度 (PP): 1
- 微批次大小 (MBS): 1
- 全局批次大小 (GBS): 128
对比结果:
下表对比了三种不同配置下训练 Llama3-70B 的性能指标。
| 配置 | 激活完全重计算 | 内存使用 (GB) | TFLOPS/GPU |
|---|---|---|---|
| CP=2 + 激活重计算 | 是 | 64.62 | 427 |
| CP=8 (无重计算) | 否 | 79.05 | 226 |
| CP=4 (无重计算) | 否 | 79.05 | 379 |
性能分析:
- 上图 (FSDP + CP4): 时间线显示计算(Kernel)和通信(蓝色条块)之间存在明显的空闲间隙。标注指出:“计算无法隐藏通信”,这意味着 GPU 在等待通信完成,导致计算单元闲置,降低了效率。
- 下图 (FSDP + 激活重计算): 计算流(Compute Stream)变得非常繁忙和密集。标注指出:“完全隐藏通信”和“繁忙的计算流”。这表明通过重计算引入的额外计算任务成功地与通信操作重叠,使得 GPU 能够持续执行计算任务,从而提高了整体的吞吐量。
结论:
在此场景下,激活重计算是最佳的映射方案。尽管激活重计算增加了约 1/3 的计算量,但通过更好地重叠计算和通信,获得了更高的整体性能,同时显著降低了内存使用。
本页总结了在训练大模型时的一些最佳实践建议。
| 实践领域 | 建议 |
|---|---|
| 计算与通信重叠 | 使用大的序列长度/微批次大小;调整通信桶(communication bucket)大小和 SM 分配。 |
| 激活内存优化 | 根据需要使用上下文并行(context parallelism)、激活卸载(activation offloading)或选择性激活重计算(selective activation recompute)。 |
| 权衡:内存节省 vs. 通信 | 仅在需要时启用重计算;调整重计算的粒度;优先考虑重叠以提高吞吐量。 |
本页提供了在 Megatron-Core 框架中启用和配置自定义 FSDP(Fully Sharded Data Parallelism)功能的具体指南。
文档链接:
https://github.com/NVIDIA/Megatron-LM/blob/main/docs/source/api-guide/custom_fsdp.md如何在 Megatron-Core 中启用自定义 FSDP?
--use-custom-fsdp--data-parallel-sharding-strategy optim_grads_params--no-gradient-accumulation-fusion # 注:自定义FSDP(c-fsdp)目前不支持 TE(TransformerEngine)的梯度累积融合。可选标志 (Optional Flags):
--calculate-per-token-loss # 用于与普通数据并行获得更好的数值对齐。--init-model-with-meta-device # 在初始化巨大模型时避免 OOM(内存不足)的必要配置。